import numpy as np
import random
from sklearn import metrics
from statsmodels.distributions.empirical_distribution import ECDF
from sklearn.covariance import ShrunkCovariance
import matplotlib.pyplot as plt

### Load the pre-computed clean and backdoor feature matrices from CSV
lwp_clean, lwp_bd = (
    np.loadtxt(csv_path, delimiter=",", dtype='float')
    for csv_path in ('./clean_features.csv', './bd_features.csv')
)


def covar_estimation(data, shrinkage):
    """Estimate a shrunk covariance matrix for ``data``.

    Fits scikit-learn's ``ShrunkCovariance`` with the caller-supplied
    ``shrinkage`` coefficient and returns the fitted ``covariance_`` matrix.

    The coefficient is taken as given; it is not tuned here (an earlier
    GridSearchCV sweep over ``np.logspace(-2, 0, 30)`` is no longer active).

    Parameters
    ----------
    data : array-like
        Samples passed straight to ``ShrunkCovariance.fit``.
    shrinkage : float
        Shrinkage coefficient forwarded to ``ShrunkCovariance``.

    Returns
    -------
    The ``covariance_`` attribute of the fitted estimator.
    """
    estimator = ShrunkCovariance(shrinkage=shrinkage)
    estimator.fit(data)
    return estimator.covariance_

def Mdistance(x, mean, cov):
    """Numerically stable Mahalanobis distance of each sample in ``x`` from ``mean``.

    The covariance is never inverted explicitly. With the Cholesky factor
    ``cov = L @ L.T``, the squared distance ``d.T @ inv(cov) @ d`` equals
    ``||y||**2`` where ``L @ y = d``, so a single linear solve covers the
    whole batch.

    Parameters
    ----------
    x : array-like, shape (n_samples, n_features) or (n_features,)
        Samples to score. A single 1-D sample is treated as a batch of one.
    mean : array-like
        Centre subtracted from ``x``; must broadcast against ``x``.
    cov : array-like, shape (n_features, n_features)
        Covariance matrix; must be positive definite.

    Returns
    -------
    numpy.ndarray, shape (n_samples,)
        One distance per sample.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``cov`` is not positive definite (the Cholesky factorisation
        fails). A least-squares fallback for that case is not implemented.
    """
    chol = np.linalg.cholesky(cov)

    # Promote a lone sample to a (1, n_features) batch so the column-wise
    # norm below is well defined instead of raising an axis error.
    centered = np.atleast_2d(np.asarray(x) - mean)

    # Solve L @ y = d for every sample at once; samples are the columns.
    whitened = np.linalg.solve(chol, centered.T)

    return np.linalg.norm(whitened, axis=0, ord=2)


def l1(data, mean):
    """Return the L1 distance of every row of ``data`` from ``mean``.

    ``mean`` is subtracted from ``data`` by broadcasting, and the absolute
    deviations are summed along axis 1, giving one value per row.
    """
    abs_deviation = np.abs(data - mean)
    return abs_deviation.sum(axis=1)



def output_cdfs(clean_features, backdoor_features, shrink, vali_size, test_size, metric = 'M', nlp = False):
    """Score clean and backdoor features by their distance to the clean validation mean.

    The clean features are shuffled with a fixed seed (67). The first
    ``vali_size`` shuffled rows form the validation set, from which the
    reference mean (and, for ``metric='M'``, the shrunk covariance) is
    estimated. Distances are then computed for the first
    ``vali_size + test_size`` shuffled clean rows and for the backdoor rows.

    Parameters
    ----------
    clean_features : numpy.ndarray
        Clean feature rows. Indexed with a permutation of ``range(4000)``
        (``range(1600)`` when ``nlp`` is true), so it must have at least
        that many rows.
    backdoor_features : numpy.ndarray
        Backdoor feature rows. When ``nlp`` is true the first 900 rows are
        used as-is; otherwise they are shuffled with the same permutation
        as the clean rows and truncated to ``vali_size + test_size``.
    shrink : float
        Shrinkage coefficient forwarded to ``covar_estimation`` (only used
        for ``metric='M'``).
    vali_size : int
        Number of shuffled clean rows used as the validation set.
    test_size : int
        Number of additional rows scored beyond the validation set.
    metric : {'M', 'l1', 'Euc'}
        ``'M'`` uses ``Mdistance``, ``'l1'`` uses ``l1`` and ``'Euc'`` the
        Euclidean norm.
    nlp : bool
        Selects the 1600-row shuffle pool and the fixed 900-row backdoor
        slice described above.

    Returns
    -------
    clean_ecdf, bd_ecdf : ECDF
        Empirical CDFs of the first ``vali_size`` clean / backdoor distances.
    clean_dis, bd_dis : numpy.ndarray
        All computed clean / backdoor distances.

    Raises
    ------
    ValueError
        If ``metric`` is not one of ``'M'``, ``'l1'`` or ``'Euc'``.

    Notes
    -----
    Calling this function reseeds NumPy's global random generator.
    """
    # Fail fast with a clear message instead of an UnboundLocalError later on.
    if metric not in ('M', 'l1', 'Euc'):
        raise ValueError(
            "metric must be one of 'M', 'l1' or 'Euc', got {!r}".format(metric)
        )

    # Fixed seed keeps the validation/test split reproducible across calls.
    pool_size = 1600 if nlp else 4000
    np.random.seed(67)
    shuffled_index = np.random.permutation(np.arange(0, pool_size))

    n_scored = vali_size + test_size
    shuffled_clean = clean_features[shuffled_index]

    vali_data = shuffled_clean[0:vali_size]
    vali_mean = np.mean(vali_data, axis = 0)

    clean_scored = shuffled_clean[0:n_scored]
    if nlp:
        bd_scored = backdoor_features[0:900]
    else:
        bd_scored = backdoor_features[shuffled_index][0:n_scored]

    if metric == 'M':
        # Only the Mahalanobis metric needs the covariance estimate.
        vali_cov = covar_estimation(vali_data, shrink)
        clean_dis = Mdistance(clean_scored, vali_mean, vali_cov)
        bd_dis = Mdistance(bd_scored, vali_mean, vali_cov)
    elif metric == 'l1':
        clean_dis = l1(clean_scored, vali_mean)
        bd_dis = l1(bd_scored, vali_mean)
    else:  # metric == 'Euc'
        clean_dis = np.linalg.norm(clean_scored - vali_mean, axis = 1, ord = 2)
        bd_dis = np.linalg.norm(bd_scored - vali_mean, axis = 1, ord = 2)

    clean_ecdf = ECDF(clean_dis[0:vali_size])
    bd_ecdf = ECDF(bd_dis[0:vali_size])

    return clean_ecdf, bd_ecdf, clean_dis, bd_dis



# Distances under each metric for the LWP attack. The ECDFs returned alongside
# are overwritten by every call and are not used by the ROC plot below.
lwp_clean_ecdf, lwp_bd_ecdf, lwp_clean_dism, lwp_bd_dism = output_cdfs(lwp_clean, lwp_bd, 0.7, 1000, 600, 
                                                                     'M', nlp = True)


lwp_clean_ecdf, lwp_bd_ecdf, lwp_clean_dis1, lwp_bd_dis1 = output_cdfs(lwp_clean, lwp_bd, 0.7, 1000, 600, 
                                                                     'l1', nlp = True)

lwp_clean_ecdf, lwp_bd_ecdf, lwp_clean_dis2, lwp_bd_dis2 = output_cdfs(lwp_clean, lwp_bd, 0.7, 1000, 600, 
                                                                     'Euc', nlp = True)

# The three lists below are index-aligned: Mahalanobis, Euclidean, L1.
clean_dis_dic = [lwp_clean_dism, lwp_clean_dis2, lwp_clean_dis1]
bd_dis_dic = [lwp_bd_dism, lwp_bd_dis2, lwp_bd_dis1]
name = ['SCM (Our Method)', r'$\ell_2$', 'MAD']


linestyles = ['solid', 'dotted', 'dashed', 'dashdot']
fig, ax = plt.subplots(figsize=(10, 10))
for label, style, clean, bd in zip(name, linestyles, clean_dis_dic, bd_dis_dic):
    # Labels are sized from the actual arrays (0 = clean, 1 = backdoor) so the
    # plot does not silently depend on a fixed 1600/900 split.
    y_true = np.concatenate((np.zeros(len(clean)), np.ones(len(bd))))
    y_score = np.concatenate((clean, bd))
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
    score = metrics.roc_auc_score(y_true, y_score)
    plt.plot(fpr, tpr, label = '{}, AUC Score:{}'.format(label, np.around(score,3)), linestyle = style, linewidth = 5)


plt.legend(loc = 'lower left', fontsize = 27, shadow = True)    
plt.xlabel('FPR', fontsize =  31)
plt.ylabel('TPR', fontsize =  31)
plt.xticks(fontsize = 22)
plt.yticks(fontsize = 22)
plt.title('Attack: LWP,' + ' Dataset: SST2 (NLP)', fontsize =  32)
plt.show()
